classdef KrusellWithRobotsEachEconomy < handle
    % Combines production function, skill distribution, participation-cost
    % distribution, and utility function into an economy.
    
    properties
        pf       % production function
        dist     % distribution object
        util     % utility
        w_lb    % lower bound for wages
        w_ub    % upper bound for wages
        lb       % lower bound for parameters
        ub       % upper bound for parameters
        q_B      % price of robots
        q_E      % price of equipment capital
        q_S      % price of structures capital
        
        varnames % variables to solve for
        
        delta_B % depreciation rate for robots
        delta_E % depreciation rate for equipment
        delta_S % depreciation rate for structures
        
        solver % name of the solver used by compute_equilibrium:
        % 'fmincon', 'knitro' (constructor default) or 'knitro_eqsolve'
        
        fix_other_capital % if true, equilibrium_equations sets the
        % residuals for equipment and structures to zero
        fix_all_but_robots % when computing equilibrium, only use
        % foc's for robots
        r % curvature parameter welfare function
        
        w % vector of wages, derived from integration nodes
        dw_dnodes % scaling factor due to change of variables in
        % integration
        
        phi_lb % lower integration bound used in the welfare
        % computations over participation costs
        
        par_dist % struct of parameters passed to the functions of dist
        par_prod_fun % struct; not read in the visible part of this
        % class (presumably production-function parameters - TODO confirm)
        revenue_requirement % exogenous revenue requirement, used
        % in resource constraint
        
        equipment_equal_robottax % if true, nonlcon_optimtax adds the
        % equality constraint tau_B - tau_E = 0
    end
    
    % Node counts and pre-computed Gauss-Legendre nodes/weights. They are
    % set by the constructor and by set_num_nodes only.
    properties (SetAccess = protected)
        num_nodes_f
        % number of nodes for integration over :math:`f()`
        num_nodes_g
        % number of nodes for integration over :math:`g()`
        num_nodes_bins
        % number of nodes for integration over bins of :math:`f()`
        num_nodes_phi
        % node count forwarded to dist.set_num_nodes; not used otherwise
        % in the visible part of this class
        
        nodes_f
        % pre-computed nodes for integration over :math:`f_i(w)`
        weights_f
        % pre-computed weights for integration over :math:`f_i(w)`
        
        nodes_g
        % pre-computed nodes for integration over :math:`g_i(w)`
        weights_g
        % pre-computed weights for integration over :math:`g_i(w)`
    end
    
    methods
        function self = KrusellWithRobotsEachEconomy(varargin)
            % Construct an economy from name-value pairs.
            %
            % Every name-value pair fills the property of the same name.
            % Unmatched names are ignored (KeepUnmatched). After parsing,
            % the node counts are pushed to the distribution object, the
            % Gauss-Legendre nodes and weights on [-1,1] are computed,
            % and the wage grid is derived from the nodes.
            %
            % Args (name-value):
            %     dist, pf, util: model objects (no validation)
            %     par_dist, par_prod_fun (struct): parameter structs
            %     w_lb, w_ub (double): positive wage bounds (required)
            %     lb, ub (numeric): solver bounds
            %     q_B, q_E, q_S (double): positive capital prices
            %     delta_B, delta_E, delta_S (double): in [0,1], default 0
            %     solver (char): default 'knitro'
            %     fix_other_capital, fix_all_but_robots,
            %         equipment_equal_robottax (logical): default false
            %     r, phi_lb, revenue_requirement (numeric)
            %     num_nodes_f/g/bins/phi (positive integer):
            %         defaults 10/15/10/50
            %     varnames (cell): names of the variables to solve for
            %
            % Raises:
            %     error if a wage bound or the distribution object is
            %     missing.
            
            p = inputParser;
            p.KeepUnmatched = true;
            
            valid_num_nodes = @(x) isnumeric(x) && x==floor(x) && (x>0);
            
            p.addParameter('dist',[])
            p.addParameter('par_dist',[], @isstruct)
            p.addParameter('par_prod_fun',[], @isstruct)
            p.addParameter('pf',[])
            p.addParameter('util',[])
            p.addParameter('w_lb', [], @(x) x>0)
            p.addParameter('w_ub', [], @(x) x>0)
            p.addParameter('lb', [], @isnumeric)
            p.addParameter('ub', [], @isnumeric)
            p.addParameter('q_B', [], @(x) x>0)
            p.addParameter('q_E', [], @(x) x>0)
            p.addParameter('q_S', [], @(x) x>0)
            
            p.addParameter('delta_B', 0, @(x) x>=0 && x<=1)
            p.addParameter('delta_E', 0, @(x) x>=0 && x<=1)
            p.addParameter('delta_S', 0, @(x) x>=0 && x<=1)
            
            p.addParameter('solver', 'knitro', @ischar)
            p.addParameter('fix_other_capital', false, @islogical)
            p.addParameter('fix_all_but_robots', false, @islogical)
            p.addParameter('r', [], @isnumeric)
            
            p.addParameter('num_nodes_f', 10, valid_num_nodes)
            p.addParameter('num_nodes_g', 15, valid_num_nodes)
            p.addParameter('num_nodes_bins', 10, valid_num_nodes)
            p.addParameter('num_nodes_phi', 50, valid_num_nodes)
            p.addParameter('phi_lb', [], @isnumeric)
            p.addParameter('revenue_requirement', 0, @isnumeric)
            p.addParameter('equipment_equal_robottax', false, @islogical)
            p.addParameter('varnames', {'L_M','L_R','L_C', ...
                'K_B_M','K_B_R','K_B_C', 'K_E','K_S'}, @iscell)
            
            p.parse(varargin{:});
            
            % every parser parameter is named after the property it
            % fills, so the results can be copied over generically
            parsed_names = fieldnames(p.Results);
            for k = 1:numel(parsed_names)
                self.(parsed_names{k}) = p.Results.(parsed_names{k});
            end
            
            if isempty(self.w_lb) || isempty(self.w_ub)
                error('bounds for wages must be provided')
            end
            
            % the distribution object is used unconditionally below;
            % fail with a readable message instead of a dot-indexing
            % error on an empty array
            if isempty(self.dist)
                error('a distribution object (''dist'') must be provided')
            end
            
            % update num_nodes for dist and dens_skill to ensure
            % they are in line
            self.dist.set_num_nodes(self.num_nodes_f, self.num_nodes_g, ...
                self.num_nodes_bins,self.num_nodes_phi);
            
            [self.nodes_f, self.weights_f] = ...
                tools.Integration.lgwt(self.num_nodes_f,-1,1);
            
            [self.nodes_g, self.weights_g] = ...
                tools.Integration.lgwt(self.num_nodes_g,-1,1);
            
            [self.w,self.dw_dnodes] = self.wages_from_nodes();
            
        end
        
        function set_num_nodes(self, num_nodes_f, num_nodes_g,num_nodes_bins,num_nodes_phi)
            % Update the node counts of this object and of its
            % distribution object, and refresh everything derived from
            % them (Gauss-Legendre nodes/weights and the wage grid).
            %
            % Args:
            %     num_nodes_f (int): number of nodes to use for
            %         Gauss-Legendre integration over :math:`f()`
            %     num_nodes_g (int): number of nodes to use for
            %         Gauss-Legendre integration over :math:`g()`
            %     num_nodes_bins (int): number of nodes to use for
            %         Gauss-Legendre integration over bins of :math:`f()`
            %     num_nodes_phi (int): node count that is stored and
            %         forwarded to the distribution object
            %
            % All four arguments must be positive integers; otherwise
            % the input parser raises an error.
            
            valid_num_nodes = @(x) isnumeric(x) && ...
                x==floor(x) && (x>0);
            
            parser = inputParser;
            arg_names = {'num_nodes_f','num_nodes_g', ...
                'num_nodes_bins','num_nodes_phi'};
            for k = 1:numel(arg_names)
                parser.addRequired(arg_names{k}, valid_num_nodes);
            end
            parser.parse(num_nodes_f, num_nodes_g, num_nodes_bins,num_nodes_phi);
            counts = parser.Results;
            
            self.num_nodes_f = counts.num_nodes_f;
            self.num_nodes_g = counts.num_nodes_g;
            self.num_nodes_bins = counts.num_nodes_bins;
            self.num_nodes_phi = counts.num_nodes_phi;
            
            % nodes, weights and the wage grid depend on the counts
            [self.nodes_f, self.weights_f] = ...
                tools.Integration.lgwt(self.num_nodes_f,-1,1);
            [self.nodes_g, self.weights_g] = ...
                tools.Integration.lgwt(self.num_nodes_g,-1,1);
            [self.w,self.dw_dnodes] = self.wages_from_nodes();
            
            % keep the distribution object in line with this object
            self.dist.set_num_nodes(num_nodes_f, num_nodes_g, ...
                num_nodes_bins,num_nodes_phi);
        end
        
        function [w,dw_dnodes] = wages_from_nodes(self)
            % Map the integration nodes nodes_f to wages between w_lb
            % and w_ub.
            %
            % Returns:
            %     w: wages at the nodes, from the log change-of-variables
            %         helper
            %     dw_dnodes: value of the matching derivative helper at
            %         the nodes
            xi = self.nodes_f;
            lo = self.w_lb;
            hi = self.w_ub;
            
            w = tools.ChangeOfVariables.change_variable_minone_one_log(xi,lo,hi);
            dw_dnodes = tools.ChangeOfVariables.change_variable_minone_one_log_deriv(xi,lo,hi);
        end
        
        function out = welfare(self, var)
            % Total welfare: welfare of participants plus welfare of
            % non-participants.
            %
            % Args:
            %     var (struct): pre-computed variables
            from_participants = self.welfare_participants(var);
            from_non_participants = self.welfare_non_participants(var);
            out = from_participants + from_non_participants;
        end
        
        function out = tot_welfare_weight_non_participants(self, var)
            % Psi_prime evaluated at var.V_b, multiplied by the share of
            % non-participants (one minus the participation rate).
            share_non_participants = 1-var.participation_rate;
            out = self.Psi_prime(var.V_b) * share_non_participants;
        end
        
        function out = welfare_non_participants(self, var)
            % Psi evaluated at var.V_b, multiplied by the share of
            % non-participants (one minus the participation rate).
            share_non_participants = 1-var.participation_rate;
            out = self.Psi(var.V_b) * share_non_participants;
        end
        
        function out = welfare_participants(self, var)
            % Welfare of participants, summed over the M, R and C
            % groups. Returns a scalar (already integrated over f_w).
            part_M = self.welfare_participants_M(var);
            part_R = self.welfare_participants_R(var);
            part_C = self.welfare_participants_C(var);
            out = part_M + part_R + part_C;
        end
        
        
        function [out, out_vector] = welfare_participants_M(self, var)
            % Welfare of participants in group M, using Psi as the
            % welfare function.
            %
            % Args:
            %     var (struct): pre-computed variables
            % Returns:
            %     out (double): already integrated welfare
            %     out_vector (double): vector of welfare at w's
            %         (corresponding to integration nodes) before
            %         integration
            % Note:
            %     out_vector is integrated over participation costs at
            %     w, but not over w itself
            psi = @(x) self.Psi(x);
            [out, out_vector] = self.welfare_computations_M(psi, var);
        end
        
        function [out, out_vector] = tot_welfare_weight_participants_M(self, var)
            % Same computation as welfare_participants_M, but with
            % Psi_prime in place of Psi.
            psi_prime = @(x) self.Psi_prime(x);
            [out, out_vector] = self.welfare_computations_M(psi_prime, var);
        end
        
        function [out, out_vector] = welfare_computations_M(self,fun,var)
            % Integrate fun(max(V_w - phi, transfer)) against the
            % density g_M over phi, from phi_lb up to phi_tilde_w, for
            % every entry of V_w, and weight the result with f_M_w.
            %
            % Args:
            %     fun (function handle): either self.Psi or
            %         self.Psi_prime
            %     var (struct): pre-computed variables; reads V_w,
            %         phi_tilde_w, transfer and f_M_w
            % Returns:
            %     out: out_vector integrated with dw_dnodes and
            %         weights_f, or [] if that product fails
            %     out_vector: one value per entry of V_w
            
            xi_row = self.nodes_g(:)';
            num_xi = length(self.nodes_g);
            
            % upper integration bound, never below phi_lb
            phi_upper = var.phi_tilde_w;
            phi_upper(phi_upper < self.phi_lb) = self.phi_lb;
            num_w = length(phi_upper(:));
            
            % one row per entry of V_w, one column per node
            V_grid = repmat(var.V_w(:),1,num_xi);
            lo_grid = self.phi_lb.*ones(size(V_grid));
            hi_grid = repmat(phi_upper(:),1,num_xi);
            xi_grid = repmat(xi_row,num_w,1);
            
            % nodes on [-1,1] mapped to [lo,hi], with the matching
            % derivative factor
            phi_grid = tools.ChangeOfVariables.change_variable_minone_one( ...
                xi_grid,lo_grid,hi_grid);
            welfare_term = fun(max(V_grid - phi_grid,var.transfer));
            density = self.dist.dens_part.g_M(phi_grid,self.par_dist);
            jacobian = tools.ChangeOfVariables.change_variable_minone_one_deriv( ...
                xi_grid,lo_grid,hi_grid);
            
            integrand = welfare_term.*density.*jacobian;
            
            out_vector = (integrand * self.weights_g(:)).* ...
                var.f_M_w(:);
            try
                % should not work if called by GPOPSObject, since
                % there dimensions of w may be different than
                % specified in this object - but does not matter,
                % since this output is not used by GPOPSObject
                % (only out_vector is used)
                out = (out_vector .* self.dw_dnodes(:))' * ...
                    self.weights_f;
            catch
                out = [];
            end
        end
        
        function [out, out_vector] = welfare_participants_R(self, var)
            % Welfare of participants in group R, using Psi as the
            % welfare function. See welfare_computations_R for the
            % meaning of the two outputs.
            psi = @(x) self.Psi(x);
            [out, out_vector] = self.welfare_computations_R(psi, var);
        end

        function [out, out_vector] = tot_welfare_weight_participants_R(self, var)
            % Same computation as welfare_participants_R, but with
            % Psi_prime in place of Psi.
            psi_prime = @(x) self.Psi_prime(x);
            [out, out_vector] = self.welfare_computations_R(psi_prime, var);
        end
        
        function [out, out_vector] = welfare_computations_R(self,fun,var)
            % Integrate fun(max(V_w - phi, transfer)) against the
            % density g_R over phi, from phi_lb up to phi_tilde_w, for
            % every entry of V_w, and weight the result with f_R_w.
            %
            % Args:
            %     fun (function handle): either self.Psi or
            %         self.Psi_prime
            %     var (struct): pre-computed variables; reads V_w,
            %         phi_tilde_w, transfer and f_R_w
            % Returns:
            %     out: out_vector integrated with dw_dnodes and
            %         weights_f, or [] if that product fails
            %     out_vector: one value per entry of V_w
            
            xi_row = self.nodes_g(:)';
            num_xi = length(self.nodes_g);
            
            % upper integration bound, never below phi_lb
            phi_upper = var.phi_tilde_w;
            phi_upper(phi_upper < self.phi_lb) = self.phi_lb;
            num_w = length(phi_upper(:));
            
            % one row per entry of V_w, one column per node
            V_grid = repmat(var.V_w(:),1,num_xi);
            lo_grid = self.phi_lb.*ones(size(V_grid));
            hi_grid = repmat(phi_upper(:),1,num_xi);
            xi_grid = repmat(xi_row,num_w,1);
            
            % nodes on [-1,1] mapped to [lo,hi], with the matching
            % derivative factor
            phi_grid = tools.ChangeOfVariables.change_variable_minone_one( ...
                xi_grid,lo_grid,hi_grid);
            welfare_term = fun(max(V_grid - phi_grid,var.transfer));
            density = self.dist.dens_part.g_R(phi_grid,self.par_dist);
            jacobian = tools.ChangeOfVariables.change_variable_minone_one_deriv( ...
                xi_grid,lo_grid,hi_grid);
            
            integrand = welfare_term.*density.*jacobian;
            
            out_vector = (integrand * self.weights_g(:)).* ...
                var.f_R_w(:);
            try
                % should not work if called by GPOPSObject, since
                % there dimensions of w may be different than
                % specified in this object - but does not matter,
                % since this output is not used by GPOPSObject
                % (only out_vector is used)
                out = (out_vector .* self.dw_dnodes(:))' * ...
                    self.weights_f;
            catch
                out = [];
            end
        end
        
        function [out, out_vector] = welfare_participants_C(self, var)
            % Welfare of participants in group C, using Psi as the
            % welfare function. See welfare_computations_C for the
            % meaning of the two outputs.
            psi = @(x) self.Psi(x);
            [out, out_vector] = self.welfare_computations_C(psi, var);
        end

        function [out, out_vector] = tot_welfare_weight_participants_C(self, var)
            % Same computation as welfare_participants_C, but with
            % Psi_prime in place of Psi.
            psi_prime = @(x) self.Psi_prime(x);
            [out, out_vector] = self.welfare_computations_C(psi_prime, var);
        end
        
        function [out, out_vector] = welfare_computations_C(self,fun,var)
            % Integrate fun(max(V_w - phi, transfer)) against the
            % density g_C over phi, from phi_lb up to phi_tilde_w, for
            % every entry of V_w, and weight the result with f_C_w.
            %
            % Args:
            %     fun (function handle): either self.Psi or
            %         self.Psi_prime
            %     var (struct): pre-computed variables; reads V_w,
            %         phi_tilde_w, transfer and f_C_w
            % Returns:
            %     out: out_vector integrated with dw_dnodes and
            %         weights_f, or [] if that product fails
            %     out_vector: one value per entry of V_w
            
            xi_row = self.nodes_g(:)';
            num_xi = length(self.nodes_g);
            
            % upper integration bound, never below phi_lb
            phi_upper = var.phi_tilde_w;
            phi_upper(phi_upper < self.phi_lb) = self.phi_lb;
            num_w = length(phi_upper(:));
            
            % one row per entry of V_w, one column per node
            V_grid = repmat(var.V_w(:),1,num_xi);
            lo_grid = self.phi_lb.*ones(size(V_grid));
            hi_grid = repmat(phi_upper(:),1,num_xi);
            xi_grid = repmat(xi_row,num_w,1);
            
            % nodes on [-1,1] mapped to [lo,hi], with the matching
            % derivative factor
            phi_grid = tools.ChangeOfVariables.change_variable_minone_one( ...
                xi_grid,lo_grid,hi_grid);
            welfare_term = fun(max(V_grid - phi_grid,var.transfer));
            density = self.dist.dens_part.g_C(phi_grid,self.par_dist);
            jacobian = tools.ChangeOfVariables.change_variable_minone_one_deriv( ...
                xi_grid,lo_grid,hi_grid);
            
            integrand = welfare_term.*density.*jacobian;
            
            out_vector = (integrand * self.weights_g(:)).* ...
                var.f_C_w(:);
            try
                % should not work if called by GPOPSObject, since
                % there dimensions of w may be different than
                % specified in this object - but does not matter,
                % since this output is not used by GPOPSObject
                % (only out_vector is used)
                out = (out_vector .* self.dw_dnodes(:))' * ...
                    self.weights_f;
            catch
                out = [];
            end
        end
        
        function out = Psi(self, V)
            % Social welfare function over indirect utilities.
            %
            % Args:
            %     V (double): indirect utilities (evaluated elementwise)
            % Returns:
            %     out (double): (V^(1-r) - 1)/(1-r), or log(V) if r == 1
            %
            % The general formula divides by zero at r == 1; its limit
            % for r -> 1 is log(V), which is used in that case.
            if self.r == 1
                out = log(V);
            else
                out = (V.^(1-self.r)-1)./(1-self.r);
            end
        end
        
        function out = Psi_prime(self, V)
            % Derivative of the social welfare function with respect to
            % V, evaluated elementwise: V^(-r).
            exponent = -self.r;
            out = V.^exponent;
        end
        
        function out = budget_constraint(self,var)
            % Budget balance: income-tax revenue minus transfers to
            % non-participants, plus the tax terms on robots, equipment
            % and structures, minus the revenue requirement, plus
            % var.compensation.
            %
            % Args:
            %     var (struct): reads tax_rev_inc, transfer,
            %         participation_rate, tau_B, tau_E, tau_S, K_B, K_E,
            %         K_S and compensation
            transfer_outlays = var.transfer * (1-var.participation_rate);
            
            robot_tax = var.tau_B * self.q_B * var.K_B;
            equipment_tax = var.tau_E * self.q_E * var.K_E;
            structures_tax = var.tau_S * self.q_S * var.K_S;
            
            out = var.tax_rev_inc - transfer_outlays ...
                + robot_tax + equipment_tax + structures_tax ...
                - self.revenue_requirement + var.compensation;
        end
        
        function [out, out_vector] = resource_constraint(self,var)
            % Aggregate resource constraint.
            %
            % Args:
            %     var (struct): pre-computed variables; reads l_w, w,
            %         c_w, h_w, transfer, participation_rate, the
            %         capital stocks K_* and the factor prices Y_*
            % Returns:
            %     out (double): participants' term plus
            %         non-participants' term plus capital income, minus
            %         capital expenditure and the revenue requirement
            %     out_vector (double): participants' term at each wage
            %         node, before integration over w
            %
            % out_vector was declared but never assigned, so any call
            % with two outputs failed. It now returns the per-node
            % participants' term, in analogy to the out_vector of the
            % welfare methods (NOTE(review): confirm this is the
            % intended content).
            
            out_vector = (var.l_w(:).*var.w(:) - ...
                var.c_w(:)).*var.h_w(:);
            
            rc_participants = (out_vector .* ...
                self.dw_dnodes(:))' * self.weights_f;
            
            rc_non_participants = -var.transfer * (1-var.participation_rate);
            
            rc_capital_expenditure =  (var.K_B * self.q_B + ...
                var.K_E * self.q_E + ...
                var.K_S * self.q_S);
            
            rc_capital_income = (var.K_B_M .* var.Y_B_M + ...
                var.K_B_R .* var.Y_B_R + ...
                var.K_B_C .* var.Y_B_C + ...
                var.K_E .* var.Y_E + ...
                var.K_S .* var.Y_S);
            
            rc_revenue_requirement = self.revenue_requirement;
            
            out = rc_participants + rc_non_participants + rc_capital_income ...
                - rc_capital_expenditure - rc_revenue_requirement;
        end
        
        
        function out = compute_elasticities(self,inp0,inp1,varargin)
            % Arc elasticities of factor prices with respect to the
            % total stock of robots, between two input points.
            %
            % Args:
            %     inp0, inp1 (struct): inputs before and after the
            %         change; K_B_M, K_B_R and K_B_C are read here, and
            %         both structs are passed to factor_prices
            %     varargin: forwarded to factor_prices
            % Returns:
            %     out (struct): robot stocks and factor prices at both
            %         points (suffix _0 and _1), their differences (d*)
            %         and the elasticities (eps_*)
            % Note:
            %     dY_B uses the robot price of sector M only, and every
            %     elasticity (including eps_Y_E_K_E and eps_Y_S_K_S)
            %     has the relative change of K_B in its denominator.
            
            % robot stocks before the change
            out.K_B_M_0 = inp0.K_B_M;
            out.K_B_R_0 = inp0.K_B_R;
            out.K_B_C_0 = inp0.K_B_C;
            out.K_B_0 = out.K_B_M_0 + out.K_B_R_0 + out.K_B_C_0;
            
            % robot stocks after the change
            out.K_B_M_1 = inp1.K_B_M;
            out.K_B_R_1 = inp1.K_B_R;
            out.K_B_C_1 = inp1.K_B_C;
            out.K_B_1 = out.K_B_M_1 + out.K_B_R_1 + out.K_B_C_1;
            
            out.dK_B = out.K_B_1 - out.K_B_0;
            
            % factor prices at both points
            [out.Y_M_0, out.Y_R_0, out.Y_C_0, ...
                out.Y_B_M_0, out.Y_B_R_0, out.Y_B_C_0, ...
                out.Y_E_0, out.Y_S_0] = self.factor_prices(inp0,varargin{:});
            
            [out.Y_M_1, out.Y_R_1, out.Y_C_1, ...
                out.Y_B_M_1, out.Y_B_R_1, out.Y_B_C_1, ...
                out.Y_E_1, out.Y_S_1] = self.factor_prices(inp1,varargin{:});
            
            % price changes
            out.dY_M = out.Y_M_1 - out.Y_M_0;
            out.dY_R = out.Y_R_1 - out.Y_R_0;
            out.dY_C = out.Y_C_1 - out.Y_C_0;
            out.dY_B = out.Y_B_M_1 - out.Y_B_M_0;
            out.dY_E = out.Y_E_1 - out.Y_E_0;
            out.dY_S = out.Y_S_1 - out.Y_S_0;
            
            % relative change in robots, shared by all elasticities
            rel_dK_B = out.dK_B/out.K_B_0;
            
            out.eps_Y_M_K_B = (out.dY_M/out.Y_M_0)/rel_dK_B;
            out.eps_Y_R_K_B = (out.dY_R/out.Y_R_0)/rel_dK_B;
            out.eps_Y_C_K_B = (out.dY_C/out.Y_C_0)/rel_dK_B;
            out.eps_Y_B_K_B = (out.dY_B/out.Y_B_M_0)/rel_dK_B;
            out.eps_Y_E_K_E = (out.dY_E/out.Y_E_0)/rel_dK_B;
            out.eps_Y_S_K_S = (out.dY_S/out.Y_S_0)/rel_dK_B;
        end
        
        function [sol,fval,exitflag,output] = ...
                compute_inp_elasticities(self,var0,lb, ub, dK_B,varargin)
            % Find an input point that satisfies nonlcon_elasticities
            % for a change dK_B in robots, starting from var0.
            %
            % The problem is passed to knitromatlab with a constant
            % objective, so only the constraints and the bounds matter.
            %
            % Args:
            %     var0: start point, converted with
            %         tools.translators.s2x
            %     lb, ub: bounds, converted the same way
            %     dK_B: change in robots handed to nonlcon_elasticities
            %     varargin: forwarded to nonlcon_elasticities
            % Returns:
            %     sol: solution converted back with
            %         tools.translators.x2s
            %     fval, exitflag, output: as returned by knitromatlab
            
            to_struct = @(x) tools.translators.x2s(self,x);
            
            start_x = tools.translators.s2x(self,var0);
            
            constant_objective = @(x) 1;
            constraints = @(x) self.nonlcon_elasticities( ...
                to_struct(x),var0,dK_B,varargin{:});
            
            lower_x = tools.translators.s2x(self,lb);
            upper_x = tools.translators.s2x(self,ub);
            
            knitrooptions = optimset( 'Display','iter');
            options_file = '../Economies/ComputeEquilibriumKnitro.opt';
            
            % no linear constraints
            [x,fval,exitflag,output] = knitromatlab(constant_objective, ...
                start_x, [],[],[],[], lower_x,upper_x, ...
                constraints,[],knitrooptions,options_file);
            
            sol = to_struct(x);
        end
        
        
        function [sol,fval,exitflag,output] = ...
                compute_equilibrium(self,var0,lb, ub, varargin)
            % Solve for the equilibrium of the economy with the solver
            % named in self.solver.
            %
            % Args:
            %     var0 (struct): struct with initial parameters
            %     lb, ub: bounds; only the 'knitro_eqsolve' solver uses
            %         them (converted with tools.translators.s2x). The
            %         'fmincon' and 'knitro' solvers use self.lb and
            %         self.ub instead.
            %     varargin: forwarded to self.nonlcon ('fmincon',
            %         'knitro_eqsolve') or to self.equilibrium_equations
            %         ('knitro')
            % Returns:
            %     sol: solution converted back with
            %         tools.translators.x2s
            %     fval, exitflag, output: as returned by the solver
            % Raises:
            %     error if self.solver is none of 'fmincon', 'knitro',
            %     'knitro_eqsolve'
            
            x0 = tools.translators.s2x(self,var0);
            
            switch self.solver
                case 'fmincon'
                    % constant objective: only the equality constraints
                    % (market clearing) matter
                    obj = @(x) 1;
                    nonlcon = @(x) self.nonlcon(tools.translators.x2s(self,x),varargin{:});
                    
                    Amat = [];
                    b = [];
                    Aeq = [];
                    beq = [];
                    lb = self.lb;
                    ub = self.ub;
                    
                    opt = optimoptions(@fmincon,'Algorithm','sqp','Display','iter',...
                        'AlwaysHonorConstraints','bounds','MaxFunctionEvaluations',10000,...
                        'StepTolerance',1e-6,'HonorBounds',true);
                    
                    
                    [x,fval,exitflag,output] = fmincon(obj,x0,Amat,b,Aeq, ...
                        beq,lb',ub',nonlcon,opt);
                    
                case 'knitro'
                    % least-squares on the equilibrium residuals
                    obj = @(x) self.equilibrium_equations(tools.translators.x2s(self,x),varargin{:});
                    
                    lb = self.lb;
                    ub = self.ub;
                    
                    knitrooptions = optimset( 'Display','iter');
                    options_file = ['../Economies/ComputeEquilibriumKnitro.opt'];
                    
                    [x,fval,exitflag,output] = knitromatlab_lsqnonlin(obj, ...
                        x0, lb,ub,[],knitrooptions,options_file);
                    
                case 'knitro_eqsolve'
                    Amat = [];
                    b = [];
                    Aeq = [];
                    beq = [];
                    
                    % constant objective: only the equality constraints
                    % (market clearing) matter
                    obj = @(x) 1;
                    nonlcon = @(x) self.nonlcon(tools.translators.x2s(self,x),varargin{:});
                    
                    lb = tools.translators.s2x(self,lb);
                    ub = tools.translators.s2x(self,ub);
                    
                    knitrooptions = optimset( 'Display','iter');
                    options_file = '../Economies/ComputeEquilibriumKnitro.opt';
                    
                    [x,fval,exitflag,output] = knitromatlab(obj, x0, Amat,b,Aeq,beq,lb,ub, ...
                        nonlcon,[],knitrooptions, ...
                        options_file);
                    
                otherwise
                    % without this branch an unknown solver name only
                    % surfaced below as "undefined variable x"
                    error('KrusellWithRobotsEachEconomy:unknownSolver', ...
                        ['unknown solver ''%s''; expected ''fmincon'', ' ...
                        '''knitro'' or ''knitro_eqsolve'''], self.solver);
            end
            
            sol = tools.translators.x2s(self,x);
        end
        
        
        function [sol,fval,exitflag,output] = compute_optimtax(self,var0,lb, ub, varargin)
            % Compute the optimal tax: minimise objective_optimtax
            % subject to nonlcon_optimtax, using knitromatlab.
            %
            % Args:
            %     var0: start point, converted with
            %         tools.translators.s2x
            %     lb, ub: bounds, converted the same way
            %     varargin: forwarded to objective_optimtax and
            %         nonlcon_optimtax
            % Returns:
            %     sol: solution converted back with
            %         tools.translators.x2s
            %     fval, exitflag, output: as returned by knitromatlab
            
            to_struct = @(x) tools.translators.x2s(self,x);
            
            start_x = tools.translators.s2x(self,var0);
            
            objective = @(x) self.objective_optimtax(to_struct(x),varargin{:});
            constraints = @(x) self.nonlcon_optimtax(to_struct(x),varargin{:});
            
            lower_x = tools.translators.s2x(self,lb);
            upper_x = tools.translators.s2x(self,ub);
            
            knitrooptions = optimset( 'Display','iter');
            options_file = '../Economies/OptimTaxKnitro.opt';
            
            % no linear constraints
            [x,fval,exitflag,output] = knitromatlab(objective, ...
                start_x, [],[],[],[], lower_x,upper_x, ...
                constraints,[],knitrooptions,options_file);
            
            sol = to_struct(x);
        end
        
        function [sol,fval,exitflag,output] = compute_compensation(self,var0,lb, ub, welfare_target,varargin)
            % Minimise objective_compensation (the compensation field
            % of the input) subject to nonlcon_compensation, which
            % includes the condition that welfare equals
            % welfare_target. Solved with knitromatlab.
            %
            % Args:
            %     var0: start point, converted with
            %         tools.translators.s2x
            %     lb, ub: bounds, converted the same way
            %     welfare_target: welfare level handed to
            %         nonlcon_compensation
            %     varargin: forwarded to objective_compensation only
            % Returns:
            %     sol: solution converted back with
            %         tools.translators.x2s
            %     fval, exitflag, output: as returned by knitromatlab
            
            to_struct = @(x) tools.translators.x2s(self,x);
            
            start_x = tools.translators.s2x(self,var0);
            
            objective = @(x) self.objective_compensation(to_struct(x),varargin{:});
            constraints = @(x) self.nonlcon_compensation(to_struct(x),welfare_target);
            
            lower_x = tools.translators.s2x(self,lb);
            upper_x = tools.translators.s2x(self,ub);
            
            knitrooptions = optimset( 'Display','iter');
            options_file = '../Economies/CompensationKnitro.opt';
            
            % no linear constraints
            [x,fval,exitflag,output] = knitromatlab(objective, ...
                start_x, [],[],[],[], lower_x,upper_x, ...
                constraints,[],knitrooptions,options_file);
            
            sol = to_struct(x);
        end
        
        function out = objective_compensation(self, inp, varargin)
            % Objective of the compensation problem: the compensation
            % field of the input.
            %
            % Args:
            %     inp (struct): must have the field compensation
            %     varargin: accepted and ignored. compute_compensation
            %         forwards its extra arguments to this objective,
            %         which failed with "too many input arguments" as
            %         long as the signature took inp only.
            out = inp.compensation;
        end
        
        function out = objective_optimtax(self, inp, varargin)
            % Objective of the optimal-tax problem: negative welfare,
            % so that a minimiser maximises welfare.
            %
            % Args:
            %     inp (struct): input point; expanded with
            %         compute_vars_from_input
            %     varargin: accepted and ignored. compute_optimtax
            %         forwards its extra arguments to this objective,
            %         which failed with "too many input arguments" as
            %         long as the signature took inp only.
            vars = self.compute_vars_from_input(inp);
            out = -self.welfare(vars);
        end
        
        function [c, ceq] = nonlcon_compensation(self, inp, welfare_target)
            % Constraints of the compensation problem.
            %
            % Args:
            %     inp (struct): input point
            %     welfare_target: welfare level that must be reached
            % Returns:
            %     c: empty (no inequality constraints)
            %     ceq: market-clearing residuals, budget constraint and
            %         the gap between welfare and welfare_target
            
            vars = self.compute_vars_from_input(inp);
            
            market_clearing = self.equilibrium_equations(vars);
            bc = self.budget_constraint(vars);
            
            % Welfare is taken from the variables computed above.
            % Going through -objective_optimtax(inp) gave the same
            % number but ran compute_vars_from_input a second time on
            % every constraint evaluation.
            welfare_level = self.welfare(vars);
            
            c = [];
            ceq = [market_clearing, bc, welfare_level - welfare_target];
        end
        
        function [c, ceq] = nonlcon_optimtax(self, inp, varargin)
            % Constraints of the optimal-tax problem.
            %
            % Args:
            %     inp (struct): input point
            %     varargin: accepted and ignored. compute_optimtax
            %         forwards its extra arguments to this function,
            %         which failed with "too many input arguments" as
            %         long as the signature took inp only.
            % Returns:
            %     c: empty (no inequality constraints)
            %     ceq: market-clearing residuals and budget constraint;
            %         if equipment_equal_robottax is set, additionally
            %         tau_B - tau_E
            
            vars = self.compute_vars_from_input(inp);
            
            market_clearing = self.equilibrium_equations(vars);
            bc = self.budget_constraint(vars);
            c = [];
            ceq = [market_clearing, bc];
            if self.equipment_equal_robottax
                ceq = [ceq,inp.tau_B - inp.tau_E];
            end
        end
        
        function [c, ceq] = nonlcon_elasticities(self,inp1,inp0,dK_B,varargin)
            % Constraints used by compute_inp_elasticities.
            %
            % Args:
            %     inp1 (struct): candidate input point
            %     inp0 (struct): reference input point
            %     dK_B: required change in the total stock of robots
            %     varargin: forwarded to factor_prices
            % Returns:
            %     c: empty (no inequality constraints)
            %     ceq: the robot price in sector M minus the one in R,
            %         the one in M minus the one in C, and the gap
            %         between the actual change in robots and dK_B
            
            % only the three robot prices are needed here
            [~, ~, ~, Y_B_M, Y_B_R, Y_B_C, ~, ~] = ...
                self.factor_prices(inp1,varargin{:});
            
            robots_before = inp0.K_B_M + inp0.K_B_R + inp0.K_B_C;
            robots_after = inp1.K_B_M + inp1.K_B_R + inp1.K_B_C;
            change_gap = (robots_after - robots_before) - dK_B;
            
            c = [];
            ceq = [Y_B_M - Y_B_R, Y_B_M - Y_B_C, change_gap];
        end
        
        
        function [c, ceq] = nonlcon(self, inp, par)
            % Constraints for the equilibrium solvers.
            %
            % Args:
            %     inp (struct): input point
            %     par: not used
            % Returns:
            %     c: empty (no inequality constraints)
            %     ceq: equilibrium residuals at the variables computed
            %         from inp
            expanded = self.compute_vars_from_input(inp);
            ceq = self.equilibrium_equations(expanded);
            c = [];
        end
        
        function out = compute_vars_from_input(self, inp)
            % Expand an input point into the full set of variables used
            % by the welfare, budget and equilibrium methods.
            %
            % Args:
            %     inp (struct): input point; read here are tau_hsv,
            %         lambda_hsv, transfer, the factor inputs L_* and
            %         K_*, the tax rates tau_B, tau_E, tau_S and,
            %         optionally, compensation
            % Returns:
            %     out (struct): copy of inp with all derived fields
            %         added (wages, labor supply, income, taxes,
            %         utilities, factor prices, densities, participation
            %         rate, tax revenues, budget constraint)
            
            out = inp;
            
            % The default has to be in place before budget_constraint
            % is called at the end of this method, because that method
            % reads the field. Setting it afterwards made every input
            % without the field fail.
            if ~isfield(inp,'compensation')
                out.compensation = 0;
            end
            
            out.w = self.w;
            out.dw_dnodes = self.dw_dnodes;
            
            xi = self.par_dist.xi;
            
            
            tau = inp.tau_hsv;
            lambda = inp.lambda_hsv;
            epsilon = self.util.epsilon;
            
            out.l_w = self.dist.l_hsv(out.w, epsilon, lambda, tau, xi);
            out.y = self.dist.y_hsv(out.w, epsilon, lambda, tau, ...
                xi);
            out.tax = self.dist.tax_hsv(out.w, epsilon, lambda, tau, xi);
            out.c_w = out.y - out.tax;
            
            
            out.V_w = self.util.utility(out.c_w, out.l_w, xi);
            out.V_b = self.util.utility(out.transfer, 0, xi);
            
            [out.Y_M,out.Y_R,out.Y_C,out.Y_B_M,out.Y_B_R,out.Y_B_C,out.Y_E,out.Y_S]=self.factor_prices(inp);
            
            out.Y = self.pf.Y(out);
            
            out.f_M_w = self.dist.dens_skill.f_M(out.w,out.Y_M,out.Y_R,out.Y_C,self.par_dist);
            out.f_R_w = self.dist.dens_skill.f_R(out.w,out.Y_M,out.Y_R,out.Y_C,self.par_dist);
            out.f_C_w = self.dist.dens_skill.f_C(out.w,out.Y_M,out.Y_R,out.Y_C,self.par_dist);
            
            out.f_w = out.f_M_w + out.f_R_w + out.f_C_w;
            
            out.phi_tilde_w = out.V_w - out.transfer;
            
            out.G_phi_tilde_M_w = ...
                self.dist.dens_part.G_M(out.phi_tilde_w, self.par_dist);
            
            out.G_phi_tilde_R_w = ...
                self.dist.dens_part.G_R(out.phi_tilde_w, self.par_dist);
            
            out.G_phi_tilde_C_w = ...
                self.dist.dens_part.G_C(out.phi_tilde_w, self.par_dist);
            
            out.K_B = out.K_B_M + out.K_B_R + out.K_B_C;
            
            out.q_B = self.q_B;
            out.q_E = self.q_E;
            out.q_S = self.q_S;
            
            % evaluate relevant functions at nodes
            out.h_M_w = self.dist.h_M(out.w, out.Y_M, out.Y_R, out.Y_C, self.par_dist);
            out.h_R_w = self.dist.h_R(out.w, out.Y_M, out.Y_R, out.Y_C, self.par_dist);
            out.h_C_w = self.dist.h_C(out.w, out.Y_M, out.Y_R, out.Y_C, self.par_dist);
            out.h_w = out.h_M_w + out.h_R_w + out.h_C_w;
            
            % columns are forced with (:) as in the integrals below, so
            % that the product does not depend on the orientation of
            % the vectors
            out.participation_rate = (out.h_w(:) .* out.dw_dnodes(:))' ...
                * self.weights_f(:);
            
            % tax revenue
            out.total_lab_inc_using_y = (out.y(:) .*out.h_w(:).* out.dw_dnodes(:))' ...
                * self.weights_f(:);
            out.total_lab_inc_using_factors = out.L_M*out.Y_M + ...
                out.L_R*out.Y_R + out.L_C*out.Y_C;
            out.tax_rev_inc = (out.tax(:) .* out.h_w(:) .* out.dw_dnodes(:))' ...
                * self.weights_f(:);
            out.exp_transfer = (1-out.participation_rate) * ...
                out.transfer;
            out.tax_rev_B = out.K_B*out.q_B*out.tau_B;
            out.tax_rev_E = out.K_E*out.q_E*out.tau_E;
            out.tax_rev_S = out.K_S*out.q_S*out.tau_S;
            
            out.tax_rev_total = out.tax_rev_inc ...
                + out.tax_rev_B + out.tax_rev_E + out.tax_rev_S;
            
            out.excess_revenue = out.tax_rev_total - out.exp_transfer;
            
            out.budget_constraint = self.budget_constraint(out);
        end
        
        
        function out = equilibrium_equations(self,inp,varargin)
            % Factor market clearing equations. Returns vector of
            % zeros in equilibrium.
            %
            % Args:
            %     inp (struct): with fields L_M, L_R, L_C, K_B_M,
            %         K_B_R, K_B_C, K_E, K_S and the tax rates tau_B,
            %         tau_E, tau_S (plus whatever labor_supply and
            %         factor_prices read)
            %     varargin: forwarded to factor_prices
            % Returns:
            %     out (1x8 double): residuals of labor market clearing
            %         for M, R, C (1-3), of the robot foc's in the three
            %         sectors (4-6) and of the foc's for equipment (7)
            %         and structures (8). Entries 7-8 are zero if
            %         fix_other_capital is set; all entries but 4-6 are
            %         zero if fix_all_but_robots is set.
            
            % factor prices
            [om_M, om_R, om_C, Y_B_M, Y_B_R, Y_B_C, Y_E, Y_S] = self.factor_prices(inp,varargin{:});
            
            out = zeros(1,8);
            
            % the robot foc's are needed in every configuration
            out(4) = self.foc_K_B(Y_B_M,inp.tau_B);
            out(5) = self.foc_K_B(Y_B_R,inp.tau_B);
            out(6) = self.foc_K_B(Y_B_C,inp.tau_B);
            
            if self.fix_all_but_robots
                % all other residuals stay zero; returning here avoids
                % three numerical integrations whose results would be
                % discarded
                return
            end
            
            % labor market clearing
            out(1) = self.L_M(om_M,om_R,om_C,inp)-inp.L_M;
            out(2) = self.L_R(om_M,om_R,om_C,inp)-inp.L_R;
            out(3) = self.L_C(om_M,om_R,om_C,inp)-inp.L_C;
            
            if ~self.fix_other_capital
                out(7) = self.foc_K_E(Y_E,inp.tau_E);
                out(8) = self.foc_K_S(Y_S,inp.tau_S);
            end
            
        end
        
        %% aggregate labor supplies
        
        function out = L_M(self,om_M,om_R,om_C,inp)
            % Aggregate labor supply for occupation M.
            %
            % Integrates ``w./om_M .* labor_supply(w,inp) .* h_M(w,...)``
            % over ``[w_lb, w_ub]`` using the pre-computed f-nodes and
            % f-weights.
            %
            % Args:
            %     om_M, om_R, om_C: labor factor prices (as returned by
            %         factor_prices), passed on to ``dist.h_M``
            %     inp (struct): forwarded to ``labor_supply``
            dens = self.dist;
            par = self.par_dist;
            f = @(w) w./om_M .* self.labor_supply(w,inp) .* ...
                dens.h_M(w,om_M,om_R,om_C,par);
            out = tools.Integration.integral_legendre(f, ...
                self.w_lb, self.w_ub, ...
                self.num_nodes_f, self.nodes_f, self.weights_f);
        end
        
        function out = L_R(self,om_M,om_R,om_C,inp)
            % Aggregate labor supply for occupation R.
            %
            % Integrates ``w./om_R .* labor_supply(w,inp) .* h_R(w,...)``
            % over ``[w_lb, w_ub]`` using the pre-computed f-nodes and
            % f-weights.
            %
            % Args:
            %     om_M, om_R, om_C: labor factor prices (as returned by
            %         factor_prices), passed on to ``dist.h_R``
            %     inp (struct): forwarded to ``labor_supply``
            dens = self.dist;
            par = self.par_dist;
            f = @(w) w./om_R .* self.labor_supply(w,inp) .* ...
                dens.h_R(w,om_M,om_R,om_C,par);
            out = tools.Integration.integral_legendre(f, ...
                self.w_lb, self.w_ub, ...
                self.num_nodes_f, self.nodes_f, self.weights_f);
        end
        
        function out = L_C(self,om_M,om_R,om_C,inp)
            % Aggregate labor supply for occupation C.
            %
            % Integrates ``w./om_C .* labor_supply(w,inp) .* h_C(w,...)``
            % over ``[w_lb, w_ub]`` using the pre-computed f-nodes and
            % f-weights.
            %
            % Args:
            %     om_M, om_R, om_C: labor factor prices (as returned by
            %         factor_prices), passed on to ``dist.h_C``
            %     inp (struct): forwarded to ``labor_supply``
            dens = self.dist;
            par = self.par_dist;
            f = @(w) w./om_C .* self.labor_supply(w,inp) .* ...
                dens.h_C(w,om_M,om_R,om_C,par);
            out = tools.Integration.integral_legendre(f, ...
                self.w_lb, self.w_ub, ...
                self.num_nodes_f, self.nodes_f, self.weights_f);
        end
        
        function out = labor_supply(self,w,inp)
            % Labor supply at wage(s) ``w``; thin wrapper around
            % ``dist.l_hsv``.
            %
            % Args:
            %     w: wage(s) at which to evaluate
            %     inp (struct): with fields lambda_hsv and tau_hsv
            %
            % The remaining arguments of ``l_hsv`` are taken from the
            % economy itself: ``util.epsilon`` and ``par_dist.xi``.
            out = self.dist.l_hsv(w, self.util.epsilon, ...
                inp.lambda_hsv, inp.tau_hsv, self.par_dist.xi);
        end
        
        %% First-order conditions (FOCs) for capital
        
        function out = foc_K_B(self,Y_B,tau_B)
            % Residual of the first-order condition for robots; zero
            % when the condition holds.
            %
            % Args:
            %     Y_B: factor price of robots (see factor_prices)
            %     tau_B: robot tax rate, entering as (1+tau_B).*q_B
            net_return = Y_B - self.delta_B;        % net of depreciation
            taxed_price = (1+tau_B).*self.q_B;      % tax-inclusive price
            out = net_return - taxed_price;
        end
        
        function out = foc_K_E(self,Y_E,tau_E)
            % Residual of the first-order condition for equipment
            % capital; zero when the condition holds.
            %
            % Args:
            %     Y_E: factor price of equipment (see factor_prices)
            %     tau_E: equipment tax rate, entering as (1+tau_E).*q_E
            net_return = Y_E - self.delta_E;        % net of depreciation
            taxed_price = (1+tau_E).*self.q_E;      % tax-inclusive price
            out = net_return - taxed_price;
        end
        
        function out = foc_K_S(self,Y_S,tau_S)
            % Residual of the first-order condition for structures
            % capital; zero when the condition holds.
            %
            % Args:
            %     Y_S: factor price of structures (see factor_prices)
            %     tau_S: structures tax rate, entering as (1+tau_S).*q_S
            net_return = Y_S - self.delta_S;        % net of depreciation
            taxed_price = (1+tau_S).*self.q_S;      % tax-inclusive price
            out = net_return - taxed_price;
        end
        
        %% statistics
        function out = participation_rate(self, inp, varargin)
            % Mass of participants at the factor prices implied by
            % ``inp``; delegates to ``dist.mass_participants``.
            %
            % Args:
            %     inp (struct): passed to factor_prices
            %     varargin: forwarded both to factor_prices and to
            %         ``dist.mass_participants``
            
            % only the three labor factor prices are needed here
            [om_M, om_R, om_C] = self.factor_prices(inp, varargin{:});
            
            bounds = {self.w_lb, self.w_ub};
            out = self.dist.mass_participants(om_M, om_R, om_C, ...
                bounds{:}, varargin{:});
        end
        
        % mass of participants by occupation
        function out = part_M(self, inp, varargin)
            % Mass of participants in occupation M: integral of
            % ``dist.h_M`` over ``[w_lb, w_ub]``.
            %
            % Args:
            %     inp (struct): passed to factor_prices
            %     varargin: forwarded to factor_prices and to
            %         ``dist.h_M``
            %
            % NOTE(review): L_M passes self.par_dist as the last
            % argument of dist.h_M, whereas this method passes
            % varargin -- confirm callers supply what h_M expects.
            [om_M, om_R, om_C] = self.factor_prices(inp,varargin{:});
            dens_M = @(w) self.dist.h_M(w,om_M,om_R,om_C,varargin{:});
            out = tools.Integration.integral_legendre(dens_M, ...
                self.w_lb, self.w_ub, ...
                self.num_nodes_f, self.nodes_f, self.weights_f);
        end
        
        function out = part_R(self, inp, varargin)
            % Mass of participants in occupation R: integral of
            % ``dist.h_R`` over ``[w_lb, w_ub]``.
            %
            % Args:
            %     inp (struct): passed to factor_prices
            %     varargin: forwarded to factor_prices and to
            %         ``dist.h_R``
            %
            % NOTE(review): L_R passes self.par_dist as the last
            % argument of dist.h_R, whereas this method passes
            % varargin -- confirm callers supply what h_R expects.
            [om_M, om_R, om_C] = self.factor_prices(inp,varargin{:});
            dens_R = @(w) self.dist.h_R(w,om_M,om_R,om_C,varargin{:});
            out = tools.Integration.integral_legendre(dens_R, ...
                self.w_lb, self.w_ub, ...
                self.num_nodes_f, self.nodes_f, self.weights_f);
        end
        
        function out = part_C(self, inp, varargin)
            % Mass of participants in occupation C: integral of
            % ``dist.h_C`` over ``[w_lb, w_ub]``.
            %
            % Args:
            %     inp (struct): passed to factor_prices
            %     varargin: forwarded to factor_prices and to
            %         ``dist.h_C``
            %
            % NOTE(review): L_C passes self.par_dist as the last
            % argument of dist.h_C, whereas this method passes
            % varargin -- confirm callers supply what h_C expects.
            [om_M, om_R, om_C] = self.factor_prices(inp,varargin{:});
            dens_C = @(w) self.dist.h_C(w,om_M,om_R,om_C,varargin{:});
            out = tools.Integration.integral_legendre(dens_C, ...
                self.w_lb, self.w_ub, ...
                self.num_nodes_f, self.nodes_f, self.weights_f);
        end
        
        function out = emp_M(self,inp,participation_rate,varargin)
            % Integral of ``dist.h_M`` over ``[w_lb, w_ub]``, scaled by
            % ``1/participation_rate`` (presumably the share of
            % occupation M among participants -- confirm).
            %
            % Args:
            %     inp (struct): passed to factor_prices
            %     participation_rate: normalizing constant
            %     varargin: forwarded to factor_prices and ``dist.h_M``
            [om_M, om_R, om_C] = self.factor_prices(inp,varargin{:});
            
            % the scaling does not depend on w: compute it once
            scale = 1/participation_rate;
            share_M = @(w) scale .* ...
                self.dist.h_M(w,om_M,om_R,om_C,varargin{:});
            out = tools.Integration.integral_legendre(share_M, ...
                self.w_lb, self.w_ub, ...
                self.num_nodes_f, self.nodes_f, self.weights_f);
        end
        
        function out = emp_R(self,inp,participation_rate,varargin)
            % Integral of ``dist.h_R`` over ``[w_lb, w_ub]``, scaled by
            % ``1/participation_rate`` (presumably the share of
            % occupation R among participants -- confirm).
            %
            % Args:
            %     inp (struct): passed to factor_prices
            %     participation_rate: normalizing constant
            %     varargin: forwarded to factor_prices and ``dist.h_R``
            [om_M, om_R, om_C] = self.factor_prices(inp,varargin{:});
            
            % the scaling does not depend on w: compute it once
            scale = 1/participation_rate;
            share_R = @(w) scale .* ...
                self.dist.h_R(w,om_M,om_R,om_C,varargin{:});
            out = tools.Integration.integral_legendre(share_R, ...
                self.w_lb, self.w_ub, ...
                self.num_nodes_f, self.nodes_f, self.weights_f);
        end
        
        function out = emp_C(self,inp,participation_rate,varargin)
            % Integral of ``dist.h_C`` over ``[w_lb, w_ub]``, scaled by
            % ``1/participation_rate`` (presumably the share of
            % occupation C among participants -- confirm).
            %
            % Args:
            %     inp (struct): passed to factor_prices
            %     participation_rate: normalizing constant
            %     varargin: forwarded to factor_prices and ``dist.h_C``
            [om_M, om_R, om_C] = self.factor_prices(inp,varargin{:});
            
            % the scaling does not depend on w: compute it once
            scale = 1/participation_rate;
            share_C = @(w) scale .* ...
                self.dist.h_C(w,om_M,om_R,om_C,varargin{:});
            out = tools.Integration.integral_legendre(share_C, ...
                self.w_lb, self.w_ub, ...
                self.num_nodes_f, self.nodes_f, self.weights_f);
        end
        
        function out = total_mass(self,inp,varargin)
            % Total mass as reported by ``dist.dens_skill.total_mass``
            % at the factor prices implied by ``inp``, over
            % ``[w_lb, w_ub]``.
            %
            % Args:
            %     inp (struct): passed to factor_prices
            %     varargin: forwarded to factor_prices and to
            %         ``dens_skill.total_mass``
            [om_M, om_R, om_C] = self.factor_prices(inp,varargin{:});
            skill_density = self.dist.dens_skill;
            out = skill_density.total_mass(om_M, om_R, om_C, ...
                self.w_lb, self.w_ub, varargin{:});
        end
        
        % average wages, conditional on participation
        
        function out = average_wage(self,inp, participation_rate,varargin)
            % Average wage, conditional on participation; delegates to
            % ``dist.average_wage``.
            %
            % Args:
            %     inp (struct): passed to factor_prices
            %     participation_rate: forwarded to ``dist.average_wage``
            %     varargin: forwarded to factor_prices and to
            %         ``dist.average_wage``
            
            % first three outputs of factor_prices: labor factor prices
            [om_M, om_R, om_C] = self.factor_prices(inp,varargin{:});
            out = self.dist.average_wage(om_M, om_R, om_C, ...
                self.w_lb, self.w_ub, ...
                participation_rate, varargin{:});
        end
        
        function out = average_wage_M(self,inp, participation_rate,varargin)
            % Average wage in occupation M, conditional on
            % participation; delegates to ``dist.average_wage_M``.
            %
            % Args:
            %     inp (struct): passed to factor_prices
            %     participation_rate: forwarded to
            %         ``dist.average_wage_M``
            %     varargin: forwarded to factor_prices and to
            %         ``dist.average_wage_M``
            
            % first three outputs of factor_prices: labor factor prices
            [om_M, om_R, om_C] = self.factor_prices(inp,varargin{:});
            out = self.dist.average_wage_M(om_M, om_R, om_C, ...
                self.w_lb, self.w_ub, ...
                participation_rate, varargin{:});
        end
        
        function out = average_wage_R(self,inp, participation_rate,varargin)
            % Average wage in occupation R, conditional on
            % participation; delegates to ``dist.average_wage_R``.
            %
            % Args:
            %     inp (struct): passed to factor_prices
            %     participation_rate: forwarded to
            %         ``dist.average_wage_R``
            %     varargin: forwarded to factor_prices and to
            %         ``dist.average_wage_R``
            
            % first three outputs of factor_prices: labor factor prices
            [om_M, om_R, om_C] = self.factor_prices(inp,varargin{:});
            out = self.dist.average_wage_R(om_M, om_R, om_C, ...
                self.w_lb, self.w_ub, ...
                participation_rate, varargin{:});
        end
        
        function out = average_wage_C(self,inp, participation_rate,varargin)
            % Average wage in occupation C, conditional on
            % participation; delegates to ``dist.average_wage_C``.
            %
            % Args:
            %     inp (struct): passed to factor_prices
            %     participation_rate: forwarded to
            %         ``dist.average_wage_C``
            %     varargin: forwarded to factor_prices and to
            %         ``dist.average_wage_C``
            
            % first three outputs of factor_prices: labor factor prices
            [om_M, om_R, om_C] = self.factor_prices(inp,varargin{:});
            out = self.dist.average_wage_C(om_M, om_R, om_C, ...
                self.w_lb, self.w_ub, ...
                participation_rate, varargin{:});
        end
        
        function out = labor_income_share(self, inp, varargin)
            % Labor income divided (element-wise) by output ``pf.Y``.
            %
            % Args:
            %     inp (struct): passed to labor_income and ``pf.Y``
            %     varargin: forwarded to labor_income and ``pf.Y``
            numerator = self.labor_income(inp, varargin{:});
            denominator = self.pf.Y(inp, varargin{:});
            out = numerator ./ denominator;
        end
        
        function out = labor_income(self, inp, varargin)
            % Total labor income: each labor input in ``inp`` times its
            % factor price, summed over M, R and C.
            %
            % Args:
            %     inp (struct): with fields L_M, L_R, L_C; also passed
            %         to factor_prices
            %     varargin: forwarded to factor_prices
            [om_M, om_R, om_C] = self.factor_prices(inp,varargin{:});
            income_M = inp.L_M.*om_M;
            income_R = inp.L_R.*om_R;
            income_C = inp.L_C.*om_C;
            out = income_M + income_R + income_C;
        end
        
        %% helper functions
        
        function [om_M, om_R, om_C, Y_B_M, Y_B_R, Y_B_C, Y_E, Y_S] = factor_prices(self,inp, varargin)
            % Evaluate all factor prices from the production function.
            %
            % Args:
            %     inp (struct): passed to every ``pf`` method
            %     varargin: forwarded unchanged to every ``pf`` method
            %
            % Returns:
            %     om_M, om_R, om_C: ``pf.Y_M``, ``pf.Y_R``, ``pf.Y_C``
            %     Y_B_M, Y_B_R, Y_B_C: ``pf.Y_B_M``, ``pf.Y_B_R``,
            %         ``pf.Y_B_C``
            %     Y_E, Y_S: ``pf.Y_E``, ``pf.Y_S``
            prod = self.pf;
            
            % labor
            om_M = prod.Y_M(inp, varargin{:});
            om_R = prod.Y_R(inp, varargin{:});
            om_C = prod.Y_C(inp, varargin{:});
            
            % robots
            Y_B_M = prod.Y_B_M(inp, varargin{:});
            Y_B_R = prod.Y_B_R(inp, varargin{:});
            Y_B_C = prod.Y_B_C(inp, varargin{:});
            
            % equipment and structures
            Y_E = prod.Y_E(inp, varargin{:});
            Y_S = prod.Y_S(inp, varargin{:});
        end
        
    end
    
end